{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5106742d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T06:53:56.418154Z",
     "start_time": "2023-06-18T06:53:56.412063Z"
    }
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c16de51d",
   "metadata": {},
   "source": [
    "## TrainingArguments"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4daad011",
   "metadata": {},
   "source": [
    "- 维护统一命名的训练参数\n",
    "\n",
    "    - `output_dir`\n",
    "    - `num_train_epochs`\n",
    "    - `evaluation_strategy`\n",
    "        - `epoch`\n",
    "    - `logging_steps`:\n",
    "        - 100，多少个optimizer steps 显示 loss及其他 metrics\n",
    "    - `per_device_train_batch_size`\n",
    "    - `per_device_eval_batch_size`\n",
    "    - `save_strategy`\n",
    "        - 'epoch'\n",
    "    - 优化器相关\n",
    "        - `learning_rate`\n",
    "        - `alpha`\n",
    "        - `weight_decay`\n",
    "        - `optim`\n",
    "            - 'adamw-torch'\n",
    "        - lr scheduler\n",
    "            - `lr_scheduler_type=\"linear\"`,\n",
    "                - linear\n",
    "                - cosine\n",
    "            - `warmup_ratio=0.1`,\n",
    "    - 精度量化\n",
    "        - `fp16`: `True`\n",
    "    - `push_to_hub`\n",
    "        - `True/False`\n",
    "- device_map\n",
    "    - `device_map = {\"\": 0}`: 只使用 gpu 0，一张卡"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e666ad",
   "metadata": {},
   "source": [
    "### lr scheduler"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbd6e2f8",
   "metadata": {},
   "source": [
    "- lr_scheduler_type\n",
    "    - warmup_steps\n",
    "    - warmup_ratio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "354aeeba",
   "metadata": {},
   "source": [
    "```\n",
    "def get_warmup_steps(self, num_training_steps: int):\n",
    "    \"\"\"\n",
    "    Get number of steps used for a linear warmup.\n",
    "    \"\"\"\n",
    "    warmup_steps = (\n",
    "        self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio)\n",
    "    )\n",
    "    return warmup_steps\n",
    "```"
   ]
  },
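  {
   "cell_type": "markdown",
   "id": "7b3e9a10",
   "metadata": {},
   "source": [
    "The warmup rule above can be reproduced in plain Python. This is a minimal sketch, not the library code: `resolve_warmup_steps` and `linear_lr_factor` are hypothetical helper names, and the decay shape assumes the `linear` scheduler (linear warmup from 0, then linear decay back to 0).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def resolve_warmup_steps(num_training_steps, warmup_steps=0, warmup_ratio=0.0):\n",
    "    # An explicit warmup_steps > 0 wins; otherwise derive the count from warmup_ratio.\n",
    "    if warmup_steps > 0:\n",
    "        return warmup_steps\n",
    "    return math.ceil(num_training_steps * warmup_ratio)\n",
    "\n",
    "def linear_lr_factor(step, num_warmup_steps, num_training_steps):\n",
    "    # Multiplier applied to the base learning rate at a given optimizer step.\n",
    "    if step < num_warmup_steps:\n",
    "        return step / max(1, num_warmup_steps)\n",
    "    remaining = num_training_steps - step\n",
    "    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))\n",
    "\n",
    "total_steps = 575\n",
    "warmup = resolve_warmup_steps(total_steps, warmup_ratio=0.1)\n",
    "print(warmup)\n",
    "print(linear_lr_factor(warmup, warmup, total_steps))\n",
    "```\n",
    "\n",
    "With 575 total steps and `warmup_ratio=0.1` this gives 58 warmup steps; the factor is 0 at step 0, reaches 1.0 at step 58, and returns to 0 at step 575."
   ]
  },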
  {
   "cell_type": "markdown",
   "id": "e6e51799",
   "metadata": {},
   "source": [
    "## Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faede99e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T06:54:09.129144Z",
     "start_time": "2023-06-18T06:54:09.110032Z"
    }
   },
   "outputs": [],
   "source": [
    "Image('../imgs/trainer.png', width=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffaa1b95",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T03:00:37.290685Z",
     "start_time": "2023-06-18T03:00:37.279750Z"
    }
   },
   "source": [
    "- Trainer pipeline\n",
    "\n",
    "```\n",
    "train()\n",
    "    inner_training_loop()\n",
    "        for epoch in range(num_train_epochs):\n",
    "            for step, inputs in enumerate(epoch_iterator):\n",
    "                tr_loss_step = self.training_step(model, inputs)\n",
    "                    loss = self.compute_loss(model, inputs)\n",
    "                    loss.backward()\n",
    "```\n",
    "\n",
    "- `model` or `model_init`（Function object）\n",
    "    - 必须指定其一；\n",
    "- 核心的成员函数\n",
    "    - `compute_loss`: batch 粒度\n",
    "- 数据\n",
    "    - `train_dataset`\n",
    "    - `eval_dataset`\n",
    "- 参数：\n",
    "    - `args`\n",
    "- tokenzier\n",
    "    - `tokenizer`\n",
    "- 重要回调函数（非成员函数）\n",
    "    - `compute_metrics`：参数类型为 `EvalPrediction`\n",
    "    \n",
    "- datasets/inputs 的关键成员\n",
    "    - `labels`：Trainer looks for a column called labels "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed2c0ce2",
   "metadata": {},
   "source": [
    "## examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1e29e17",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:07:02.157737Z",
     "start_time": "2023-06-18T08:07:02.150920Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5bab2ce",
   "metadata": {},
   "source": [
    "### datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "107d091a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:47:29.483478Z",
     "start_time": "2023-06-18T08:47:27.104893Z"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "raw_datasets = load_dataset('glue', 'mrpc')\n",
    "ckpt = 'bert-base-uncased'\n",
    "tokenizer = AutoTokenizer.from_pretrained(ckpt)\n",
    "\n",
    "def tokenize_func(examples):\n",
    "    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True)\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(tokenize_func, batched=True)\n",
    "data_collator = DataCollatorWithPadding(tokenizer)"
   ]
  },
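  {
   "cell_type": "markdown",
   "id": "7b3e9a11",
   "metadata": {},
   "source": [
    "`tokenize_func` truncates but does not pad, because `DataCollatorWithPadding` pads each batch only up to the longest sequence in that batch (dynamic padding). A minimal pure-Python sketch of the idea, not the library implementation: `pad_batch` is a hypothetical helper, the token ids are made up for illustration, and pad id 0 is an assumption that matches BERT's `[PAD]` token.\n",
    "\n",
    "```python\n",
    "def pad_batch(batch, pad_id=0):\n",
    "    # Pad every sequence to the length of the longest one in this batch only.\n",
    "    max_len = max(len(ids) for ids in batch)\n",
    "    input_ids = [ids + [pad_id] * (max_len - len(ids)) for ids in batch]\n",
    "    # The attention mask marks real tokens with 1 and padding with 0.\n",
    "    attention_mask = [[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch]\n",
    "    return {'input_ids': input_ids, 'attention_mask': attention_mask}\n",
    "\n",
    "batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])\n",
    "print(batch['input_ids'])\n",
    "print(batch['attention_mask'])\n",
    "```\n",
    "\n",
    "Padding per batch instead of to a global maximum length keeps short batches short, which saves computation."
   ]
  },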
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f80d47",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:01:11.235019Z",
     "start_time": "2023-06-18T08:01:11.225709Z"
    }
   },
   "outputs": [],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "360e65e4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T07:13:17.181606Z",
     "start_time": "2023-06-18T07:13:17.172671Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "raw_datasets['train'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "682e29b3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T07:55:29.947787Z",
     "start_time": "2023-06-18T07:55:29.912228Z"
    }
   },
   "outputs": [],
   "source": [
    "input_ids = tokenizer(raw_datasets['validation']['sentence1'], \n",
    "                      raw_datasets['validation']['sentence2'], \n",
    "                      truncation=True)['input_ids']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47249914",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T07:55:31.501736Z",
     "start_time": "2023-06-18T07:55:31.491898Z"
    }
   },
   "outputs": [],
   "source": [
    "len(input_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44194e76",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T07:56:34.071656Z",
     "start_time": "2023-06-18T07:56:34.062640Z"
    }
   },
   "outputs": [],
   "source": [
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b051736f",
   "metadata": {},
   "source": [
    "### model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "48b2433b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:47:38.652241Z",
     "start_time": "2023-06-18T08:47:36.809957Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "model = AutoModelForSequenceClassification.from_pretrained(ckpt, num_labels=2)\n",
    "# model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d8a02d4",
   "metadata": {},
   "source": [
    "### training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e882e4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:15:42.355043Z",
     "start_time": "2023-06-18T08:15:42.343277Z"
    }
   },
   "outputs": [],
   "source": [
    "tokenized_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8405eac",
   "metadata": {},
   "source": [
    "- `TrainingArguments`: `@dataclass`\n",
    "    - 只是用来存参数配置的；\n",
    "    - batch_size 与 global_step \n",
    "        - per_device_train_batch_size=16,\n",
    "            - per_device_eval_batch_size=16,\n",
    "        - num_train_epochs=5,\n",
    "        - `5*3668 /(16*2) == 574`\n",
    "            - 2 表示我本机 gpus 的数量\n",
    "    - `evaluation_strategy`\n",
    "        - `epoch`\n",
    "        - `steps`\n",
    "    - `logging_strategy`：如果不指定的话，输出的 log 显示上 `Training Loss`（no log）\n",
    "        - `epoch`\n",
    "    - 梯度优化相关\n",
    "        - `gradient_accumulation_steps`\n",
    "        \n",
    "- `Trainer`\n",
    "    - `data_collator`: \n",
    "        - `DataCollatorWithPadding(tokenizer)`\n",
    "            - dynamic padding"
   ]
  },
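  {
   "cell_type": "markdown",
   "id": "7b3e9a12",
   "metadata": {},
   "source": [
    "A minimal sketch of where `global_step` comes from, assuming no gradient accumulation and that the final partial batch of each epoch is kept. `total_optimizer_steps` is a hypothetical helper, not a `Trainer` API.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def total_optimizer_steps(num_examples, per_device_batch_size, num_devices, num_epochs):\n",
    "    # One optimizer step consumes one batch per device; a partial last batch still counts.\n",
    "    steps_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))\n",
    "    return num_epochs * steps_per_epoch\n",
    "\n",
    "print(total_optimizer_steps(3668, 16, 2, 5))\n",
    "```\n",
    "\n",
    "For the 3668 MRPC training examples on 2 GPUs this gives 115 steps per epoch and 575 steps in total, the `global_step` that `trainer.train()` reports."
   ]
  },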
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2714e761",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:16:39.137192Z",
     "start_time": "2023-06-18T08:16:39.126243Z"
    }
   },
   "outputs": [],
   "source": [
    "5*3668 /(16*2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a1f25c72",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:47:42.843238Z",
     "start_time": "2023-06-18T08:47:42.803630Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    \"test-trainer\",\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=5,\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    "    evaluation_strategy='epoch',\n",
    "    logging_strategy='epoch'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "60ccd4dc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:47:49.282600Z",
     "start_time": "2023-06-18T08:47:44.683746Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_4020411/1605526023.py:3: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
      "  metric = load_metric(\"glue\", \"mrpc\")\n",
      "Using the latest cached version of the module from /home/whaow/.cache/huggingface/modules/datasets_modules/metrics/glue/91f3cfc5498873918ecf119dbf806fb10815786c84f41b85a5d3c47c1519b343 (last modified on Sun Jun 18 16:07:08 2023) since it couldn't be found locally at glue, or remotely on the Hugging Face Hub.\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "import numpy as np\n",
    "metric = load_metric(\"glue\", \"mrpc\")\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
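  {
   "cell_type": "markdown",
   "id": "7b3e9a13",
   "metadata": {},
   "source": [
    "The MRPC metric reports accuracy and F1. A minimal pure-Python sketch of what `compute_metrics` computes, assuming binary labels with class 1 as the positive class; `accuracy_and_f1` is a hypothetical helper and the logits are made up for illustration.\n",
    "\n",
    "```python\n",
    "def accuracy_and_f1(predictions, references):\n",
    "    pairs = list(zip(predictions, references))\n",
    "    tp = sum(1 for p, r in pairs if p == 1 and r == 1)\n",
    "    fp = sum(1 for p, r in pairs if p == 1 and r == 0)\n",
    "    fn = sum(1 for p, r in pairs if p == 0 and r == 1)\n",
    "    accuracy = sum(1 for p, r in pairs if p == r) / len(pairs)\n",
    "    # F1 is the harmonic mean of precision and recall for the positive class.\n",
    "    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0\n",
    "    return {'accuracy': accuracy, 'f1': f1}\n",
    "\n",
    "logits = [[0.2, 1.3], [2.0, -1.0], [0.1, 0.4], [1.5, 0.3]]\n",
    "labels = [1, 0, 0, 0]\n",
    "# argmax over the class dimension, as np.argmax(logits, axis=-1) does\n",
    "preds = [max(range(len(row)), key=row.__getitem__) for row in logits]\n",
    "result = accuracy_and_f1(preds, labels)\n",
    "print(preds, result)\n",
    "```\n",
    "\n",
    "Here the predictions are `[1, 0, 1, 0]`, so accuracy is 0.75 and F1 is 2/3 (one true positive, one false positive, no false negatives)."
   ]
  },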
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dbb81e75",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:49:40.333368Z",
     "start_time": "2023-06-18T08:47:59.858475Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/whaow/anaconda3/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  warnings.warn(\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlanchunhui\u001b[0m (\u001b[33mloveresearch\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "wandb version 0.15.4 is available!  To upgrade, please run:\n",
       " $ pip install wandb --upgrade"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.15.0"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/whaow/workspaces/bert_t5_gpt/tutorials/wandb/run-20230618_164803-tvyaf9f0</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/loveresearch/huggingface/runs/tvyaf9f0' target=\"_blank\">desert-haze-41</a></strong> to <a href='https://wandb.ai/loveresearch/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/loveresearch/huggingface' target=\"_blank\">https://wandb.ai/loveresearch/huggingface</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/loveresearch/huggingface/runs/tvyaf9f0' target=\"_blank\">https://wandb.ai/loveresearch/huggingface/runs/tvyaf9f0</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='575' max='575' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [575/575 01:25, Epoch 5/5]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.545300</td>\n",
       "      <td>0.443570</td>\n",
       "      <td>0.816176</td>\n",
       "      <td>0.877651</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.350100</td>\n",
       "      <td>0.367913</td>\n",
       "      <td>0.830882</td>\n",
       "      <td>0.880829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.206200</td>\n",
       "      <td>0.433944</td>\n",
       "      <td>0.840686</td>\n",
       "      <td>0.888889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.117900</td>\n",
       "      <td>0.515514</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.885522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.076800</td>\n",
       "      <td>0.483391</td>\n",
       "      <td>0.848039</td>\n",
       "      <td>0.891228</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "/home/whaow/anaconda3/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=575, training_loss=0.25926235033118206, metrics={'train_runtime': 98.4745, 'train_samples_per_second': 186.241, 'train_steps_per_second': 5.839, 'total_flos': 753299284826400.0, 'train_loss': 0.25926235033118206, 'epoch': 5.0})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model,\n",
    "    training_args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    "    data_collator=data_collator,\n",
    "    tokenizer=tokenizer,\n",
    "    compute_metrics=compute_metrics\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9f0d6d8",
   "metadata": {},
   "source": [
    "### inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c721e4ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:43:28.278925Z",
     "start_time": "2023-06-18T08:43:27.642357Z"
    }
   },
   "outputs": [],
   "source": [
    "predictions = trainer.predict(tokenized_datasets['validation'])\n",
    "predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce082176",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-06-18T08:13:38.656189Z",
     "start_time": "2023-06-18T08:13:36.946061Z"
    }
   },
   "outputs": [],
   "source": [
    "from datasets import load_metric\n",
    "import numpy as np\n",
    "metric = load_metric(\"glue\", \"mrpc\")\n",
    "preds = np.argmax(predictions.predictions, axis=-1)\n",
    "metric.compute(predictions=preds, references=predictions.label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "203e61b6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
